{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "tflite-transferlearning.ipynb",
      "provenance": [],
      "include_colab_link": true
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "view-in-github",
        "colab_type": "text"
      },
      "source": [
        "<a href=\"https://colab.research.google.com/github/lmoroney/tfbook/blob/master/chapter12/tflite-transferlearning.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "zX4Kg8DUTKWO",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "D1J15Vh_1Jih",
        "colab_type": "code",
        "cellView": "both",
        "colab": {}
      },
      "source": [
        "try:\n",
        "  # %tensorflow_version only exists in Colab.\n",
        "  %tensorflow_version 2.x\n",
        "except Exception:\n",
        "  pass\n"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "sqNRQoc7W17k",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "import numpy as np\n",
        "import matplotlib.pylab as plt\n",
        "\n",
        "import tensorflow as tf\n",
        "import tensorflow_hub as hub\n",
        "import tensorflow_datasets as tfds\n",
        "\n",
        "\n",
        "def format_image(image, label):\n",
        "    image = tf.image.resize(image, (224, 224)) / 255.0\n",
        "    return  image, label\n",
        "\n",
        "\n",
        "(raw_train, raw_validation, raw_test), metadata = tfds.load(\n",
        "    'cats_vs_dogs',\n",
        "    split=['train[:80%]', 'train[80%:90%]', 'train[90%:]'],\n",
        "    with_info=True,\n",
        "    as_supervised=True,\n",
        ")\n",
        "\n",
        "num_examples = metadata.splits['train'].num_examples\n",
        "num_classes = metadata.features['label'].num_classes\n",
        "print(num_examples)\n",
        "print(num_classes)\n",
        "\n",
        "BATCH_SIZE = 32\n",
        "train_batches = raw_train.shuffle(num_examples // 4).map(format_image).batch(BATCH_SIZE).prefetch(1)\n",
        "validation_batches = raw_validation.map(format_image).batch(BATCH_SIZE).prefetch(1)\n",
        "test_batches = raw_test.map(format_image).batch(1)\n",
        "\n",
        "for image_batch, label_batch in train_batches.take(1):\n",
        "    pass\n",
        "\n",
        "image_batch.shape"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "inIK227eZbgV",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "module_selection = (\"mobilenet_v2\", 224, 1280) \n",
        "handle_base, pixels, FV_SIZE = module_selection\n",
        "MODULE_HANDLE =\"https://tfhub.dev/google/tf2-preview/{}/feature_vector/4\".format(handle_base)\n",
        "IMAGE_SIZE = (pixels, pixels)\n",
        "print(\"Using {} with input size {} and output dimension {}\".format(MODULE_HANDLE, IMAGE_SIZE, FV_SIZE))\n",
        "\n",
        "feature_extractor = hub.KerasLayer(MODULE_HANDLE,\n",
        "                                   input_shape=IMAGE_SIZE + (3,), \n",
        "                                   output_shape=[FV_SIZE],\n",
        "                                   trainable=False)\n",
        "\n",
        "print(\"Building model with\", MODULE_HANDLE)\n",
        "\n",
        "model = tf.keras.Sequential([\n",
        "        feature_extractor,\n",
        "        tf.keras.layers.Dense(num_classes, activation='softmax')\n",
        "])\n",
        "\n",
        "model.summary()\n",
        "\n",
        "model.compile(optimizer='adam',\n",
        "                  loss='sparse_categorical_crossentropy',\n",
        "                  metrics=['accuracy'])\n",
        "\n",
        "EPOCHS = 5\n",
        "\n",
        "hist = model.fit(train_batches,\n",
        "                 epochs=EPOCHS,\n",
        "                 validation_data=validation_batches)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "cML-V4yqtuYI",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "CATS_VS_DOGS_SAVED_MODEL = \"exp_saved_model\"\n",
        "tf.saved_model.save(model, CATS_VS_DOGS_SAVED_MODEL)\n"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "gButZXqZt3o4",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "\n",
        "converter = tf.lite.TFLiteConverter.from_saved_model(CATS_VS_DOGS_SAVED_MODEL)\n",
        "\n",
        "converter.optimizations = [tf.lite.Optimize.DEFAULT]\n",
        "\n",
        "def representative_data_gen():\n",
        "    for input_value, _ in test_batches.take(100):\n",
        "        yield [input_value]\n",
        "\n",
        "converter.representative_dataset = representative_data_gen\n",
        "converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]\n",
        "\n",
        "tflite_model = converter.convert()\n",
        "tflite_model_file = 'converted_model.tflite'\n",
        "\n",
        "with open(tflite_model_file, \"wb\") as f:\n",
        "    f.write(tflite_model)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "uWipWGb0T_Gg",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "tflite_model_file = 'converted_model_withoptimizationsandquant.tflite'"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "MsNjuPhxuDOx",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "from tqdm import tqdm\n",
        "# Load TFLite model and allocate tensors.\n",
        "  \n",
        "interpreter = tf.lite.Interpreter(model_path=tflite_model_file)\n",
        "interpreter.allocate_tensors()\n",
        "\n",
        "input_index = interpreter.get_input_details()[0][\"index\"]\n",
        "output_index = interpreter.get_output_details()[0][\"index\"]\n",
        "\n",
        "predictions = []\n",
        "\n",
        "test_labels, test_imgs = [], []\n",
        "for img, label in tqdm(test_batches.take(100)):\n",
        "    interpreter.set_tensor(input_index, img)\n",
        "    interpreter.invoke()\n",
        "    predictions.append(interpreter.get_tensor(output_index))\n",
        "    \n",
        "    test_labels.append(label.numpy()[0])\n",
        "    test_imgs.append(img)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "b-bTxhk5S4HA",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "score = 0\n",
        "for item in range(0,99):\n",
        "  prediction=np.argmax(predictions[item])\n",
        "  label = test_labels[item]\n",
        "  if prediction==label:\n",
        "    score=score+1\n",
        "\n",
        "print(\"Out of 100 predictions I got \" + str(score) + \" correct\")\n"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "gRgWsPmDuJEI",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "\n",
        "#@title Utility functions for plotting\n",
        "# Utilities for plotting\n",
        "\n",
        "class_names = ['cat', 'dog']\n",
        "\n",
        "def plot_image(i, predictions_array, true_label, img):\n",
        "    predictions_array, true_label, img = predictions_array[i], true_label[i], img[i]\n",
        "    plt.grid(False)\n",
        "    plt.xticks([])\n",
        "    plt.yticks([])\n",
        "    \n",
        "    img = np.squeeze(img)\n",
        "\n",
        "    plt.imshow(img, cmap=plt.cm.binary)\n",
        "    \n",
        "    predicted_label = np.argmax(predictions_array)\n",
        "    \n",
        "    if predicted_label == true_label:\n",
        "        color = 'green'\n",
        "    else:\n",
        "        color = 'red'\n",
        "    \n",
        "    plt.xlabel(\"{} {:2.0f}% ({})\".format(class_names[predicted_label],\n",
        "                                         100*np.max(predictions_array),\n",
        "                                         class_names[true_label]), color=color)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "TylXFi0quLEw",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "\n",
        "#@title Visualize the outputs { run: \"auto\" }\n",
        "index = 73 #@param {type:\"slider\", min:0, max:99, step:1}\n",
        "for index in range(0,99):\n",
        "  plt.figure(figsize=(6,3))\n",
        "  plt.subplot(1,2,1)\n",
        "  plot_image(index, predictions, test_labels, test_imgs)\n",
        "  plt.show()"
      ],
      "execution_count": 0,
      "outputs": []
    }
  ]
}